import torch
import torch.nn.functional as F 
from torch.utils.data import DataLoader
import h5py
import os
import numpy as np
import operator
from itertools import accumulate
import copy
from utils import NodeType
from utils import calc_face_to_face_distance
import math

# Module-level accumulators filled by faces_to_faces(): every call appends its
# face-pair array to previous_face_edges and one new cumulative counter value
# to frame_index (which starts at 0). Nothing shown here ever clears them.
previous_face_edges: list = []
frame_index: list = [0]
def custom_collate(data):
    """Collate per-frame sample dicts into a single disjoint-union batch.

    Every value is converted with ``torch.tensor`` and concatenated along
    dim 0. Index arrays are first shifted, per sample, by the number of rows
    the earlier samples contributed to the table they point into:

    * node indices ("elements", "faces", "mesh_edge", "world_edge") by the
      running node count (rows of "stress");
    * "faces_edges" by the running element count (rows of "elements");
    * "faces_to_faces" and "cells_faces" by the running face count
      (rows of "faces").

    Args:
        data: list of sample dicts as written by GetDataSetDP.

    Returns:
        dict of concatenated tensors with the same keys as each sample.
    """
    def joined(key):
        # Plain concatenation: the values carry no cross-sample indices.
        return torch.concat([torch.tensor(sample[key]) for sample in data], dim=0)

    def start_offsets(key):
        # Exclusive prefix sum of the per-sample row counts of `key`.
        sizes = [sample[key].shape[0] for sample in data]
        return torch.tensor([0] + list(accumulate(sizes))[:-1])

    def joined_shifted(key, offsets):
        shifted = [torch.tensor(sample[key]) + offsets[k] for k, sample in enumerate(data)]
        return torch.concat(shifted, dim=0)

    node_offsets = start_offsets("stress")
    element_offsets = start_offsets("elements")
    face_offsets = start_offsets("faces")

    batch = {}
    batch["final_flag"] = joined("final_flag")
    batch["cells_faces"] = joined_shifted("cells_faces", face_offsets)
    batch["last_stress"] = joined("last_stress")
    batch["last_pos"] = joined("last_pos")
    batch["world_pos"] = joined("world_pos")
    batch["stress"] = joined("stress")
    batch["node_type"] = joined("node_type")
    batch["target"] = joined("target")
    batch["next_pos_need"] = joined("next_pos_need")
    batch["elements"] = joined_shifted("elements", node_offsets)
    batch["faces"] = joined_shifted("faces", node_offsets)
    batch["faces_edges"] = joined_shifted("faces_edges", element_offsets)
    batch["faces_to_faces"] = joined_shifted("faces_to_faces", face_offsets)
    batch["mesh_edge"] = joined_shifted("mesh_edge", node_offsets)
    batch["mesh_pos"] = joined("mesh_pos")
    batch["world_edge"] = joined_shifted("world_edge", node_offsets)
    return batch
def cells_to_edges(cells):
    """Return the bidirectional, de-duplicated edge list of a tetrahedral mesh.

    Every pair among a cell's first four vertices is an edge. Each undirected
    edge is normalised to (larger id, smaller id). Copies shared by
    neighbouring cells are dropped while keeping first-occurrence order. The
    reversed rows are then appended, so the result lists both directions:
    shape (2E, 2).
    """
    vertex_pairs = [(i, j) for i in range(4) for j in range(i + 1, 4)]
    all_edges = np.concatenate(
        [np.stack([cells[:, i], cells[:, j]], axis=1) for i, j in vertex_pairs],
        axis=0,
    )
    # Canonical orientation so that (a, b) and (b, a) collapse into one row.
    canonical = np.stack([all_edges.max(axis=1), all_edges.min(axis=1)], axis=1)
    uniq, first_seen = np.unique(canonical, return_index=True, axis=0)
    # np.unique sorts rows lexicographically; restore first-occurrence order.
    forward = uniq[np.argsort(first_seen)]
    backward = forward[:, ::-1]
    return np.concatenate([forward, backward], axis=0)

def find_intersecting_pairs(B, A):
    """Return all (a, b) index pairs whose closed 1-D intervals overlap.

    ``A`` and ``B`` are (n, 2) arrays of [lo, hi] intervals. A sweep over the
    sorted endpoints tracks the currently open intervals of each family.
    When an interval opens, it is paired with every open interval of the
    other family. Each returned tuple is (index into A, index into B).

    At a shared coordinate, openings are handled before closings. Intervals
    that only touch at one point therefore still count as overlapping.
    """
    # Tie-break rank at equal coordinates: all starts precede all ends.
    rank = {'Start_B': 0, 'Start_A': 1, 'End_B': 2, 'End_A': 3}

    events = []
    for side, intervals in (('A', A), ('B', B)):
        for i in range(intervals.shape[0]):
            events.append((intervals[i, 0], 'Start_' + side, i))
            events.append((intervals[i, 1], 'End_' + side, i))
    events.sort(key=lambda ev: (ev[0], rank[ev[1]]))

    active = {'A': set(), 'B': set()}
    pairs = []
    for _, kind, i in events:
        phase, side = kind.split('_')
        if phase == 'End':
            active[side].discard(i)
            continue
        # Output tuples are always ordered (A index, B index).
        if side == 'A':
            pairs.extend((i, j) for j in active['B'])
        else:
            pairs.extend((j, i) for j in active['A'])
        active[side].add(i)
    return pairs
def faces_to_faces(faces, node_pos, node_type, mask, mask1, threshold=0.01):
    """Find candidate contact pairs between two sets of triangular faces.

    Each face's axis-aligned bounding box is computed from ``node_pos``.
    Boxes of the faces selected by ``mask1`` are inflated by ``threshold``
    on every side. A (mask1 face, mask face) pair is kept when the boxes
    overlap on all three axes; touching counts as overlap.

    Args:
        faces: (F, 3) node indices of each face.
        node_pos: node coordinates, indexed by ``faces``; 3-D positions are
            assumed by the reshape below.
        node_type: unused; kept for call-site compatibility.
        mask: length-F 0/1 array selecting the first face set.
        mask1: length-F 0/1 array selecting the second (inflated) face set.
        threshold: inflation distance applied to the ``mask1`` boxes.

    Returns:
        (E, 2) integer array of global face indices. Column 0 is a ``mask1``
        face and column 1 a ``mask`` face. It is an empty (0, 2) array when
        no pair overlaps; previously that case raised IndexError.

    Side effects:
        Appends the result to the module-level ``previous_face_edges`` list.
        Appends the running total of pairs returned so far to
        ``frame_index``. Neither list is cleared in the code shown.
    """
    faces_pos = node_pos[faces.reshape(-1)].reshape(-1, 3, 3)

    na_index = np.flatnonzero(mask == 1)
    nb_index = np.flatnonzero(mask1 == 1)

    na_pos = faces_pos[na_index]
    nb_pos = faces_pos[nb_index]
    na_lo, na_hi = na_pos.min(axis=1), na_pos.max(axis=1)
    # Only the mask1 boxes are padded, so near-misses within `threshold` still pair up.
    nb_lo = nb_pos.min(axis=1) - threshold
    nb_hi = nb_pos.max(axis=1) + threshold

    # Per-axis interval overlap; each pair is (nb-local index, na-local index).
    x_pairs, y_pairs, z_pairs = (
        find_intersecting_pairs(
            np.stack([na_lo[:, d], na_hi[:, d]], axis=1),
            np.stack([nb_lo[:, d], nb_hi[:, d]], axis=1),
        )
        for d in range(3)
    )
    x_set, y_set = set(x_pairs), set(y_pairs)
    local_pairs = [p for p in z_pairs if p in x_set and p in y_set]

    if local_pairs:
        local = np.array(local_pairs)
        edge_index = np.stack([nb_index[local[:, 0]], na_index[local[:, 1]]], axis=1)
    else:
        # No overlapping boxes: np.array([]) would be 1-D and cannot be column-indexed.
        edge_index = np.empty((0, 2), dtype=nb_index.dtype)

    # edge_index is (E, 2): the pair count is shape[0], not shape[1] (always 2).
    frame_index.append(frame_index[-1] + edge_index.shape[0])
    previous_face_edges.append(edge_index)
    return edge_index
                
def nodes_to_edges(nodes, node_type):
    """Build bidirectional proximity edges between plate and obstacle nodes.

    Plate nodes are those of type NodeType.NORMAL or NodeType.HANDLE.
    Obstacle nodes are those of type NodeType.OBSTACLE. A plate node and an
    obstacle node are connected when their Euclidean distance in ``nodes``
    is below 0.03.

    Returns:
        (2E, 2) array of global node indices. The first E rows are
        (obstacle, plate) pairs and the next E rows are the same pairs
        reversed.
    """
    radius = 0.03
    plate_mask = np.logical_or(node_type == NodeType.NORMAL, node_type == NodeType.HANDLE)
    obstacle_mask = node_type == NodeType.OBSTACLE
    plate_ids = np.nonzero(plate_mask.reshape(-1))[0]
    obstacle_ids = np.nonzero(obstacle_mask.reshape(-1))[0]
    plate_pos = nodes[plate_mask.reshape(-1), :]
    obstacle_pos = nodes[obstacle_mask.reshape(-1), :]

    # Dense (num_obstacle, num_plate) matrix of pairwise distances.
    dist = np.sqrt(np.sum((plate_pos - obstacle_pos[:, None, :]) ** 2, axis=-1))
    obstacle_local, plate_local = np.where(dist < radius)
    forward = np.stack([obstacle_ids[obstacle_local], plate_ids[plate_local]], axis=1)
    return np.concatenate([forward, forward[:, ::-1]], axis=0)

def cells_to_faces(cells):
    """Enumerate the unique triangular faces of a tetrahedral mesh.

    Each cell contributes the four triangles formed by leaving out one of
    its first four vertices. Node ids inside each triangle are sorted, and
    identical triangles are merged.

    Returns:
        unique_faces: (F, 3) float array of sorted node ids, in
            lexicographic order. The ids are floats because they were
            carried alongside a float owner column.
        faces_edges: (F, 2) float array with the ascending ids of the cells
            owning each face. Faces with a single owner are padded with
            -1e9. This assumes a conforming mesh where every face has one or
            two owners; otherwise the rows would be ragged.
        cells_faces: (C, 4) int array. For each cell, it holds the indices
            into ``unique_faces`` of that cell's faces, ascending.
    """
    n_cells = cells.shape[0]
    # One (n_cells, 3) block per triangle type, each omitting one tet vertex.
    corner_triples = ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3))
    tris = np.concatenate([cells[:, list(t)] for t in corner_triples], axis=0)
    tris = np.sort(tris, axis=1)
    # Blocks are stacked per triangle type, so row r belongs to cell r % n_cells.
    owner_col = np.tile(np.arange(n_cells), 4).astype(float).reshape(-1, 1)
    records = np.concatenate([tris, owner_col], axis=1).tolist()
    # Lexicographic sort groups equal triangles and orders owners inside a group.
    records.sort()

    unique_faces = []
    face_owners = []
    faces_of_cell = [[] for _ in range(n_cells)]
    for *nodes, owner in records:
        if not unique_faces or nodes != unique_faces[-1]:
            unique_faces.append(nodes)
            face_owners.append([owner])
        else:
            face_owners[-1].append(owner)
        faces_of_cell[int(owner)].append(len(unique_faces) - 1)

    for owner_ids in face_owners:
        if len(owner_ids) == 1:
            owner_ids.append(-1e9)  # boundary face: no second owner

    return np.array(unique_faces), np.array(face_owners), np.array(faces_of_cell)


# Default torch device: CUDA when torch reports it available, otherwise CPU.
# NOTE(review): not referenced by the code visible in this module; confirm before removing.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def node_sorted(x, pos):
    """Reorder each row of ``x`` by the distance of its nodes to the row centroid.

    Args:
        x: (M, K) integer array. Each row lists the K node indices of one
            element, e.g. a triangle or a tetrahedron.
        pos: (N, 3) node coordinates indexed by the entries of ``x``.

    Returns:
        A new (M, K) array with ``x``'s dtype. In every row, the node indices
        are ordered by ascending Euclidean distance from the mean position of
        that row's nodes. Ties keep their original relative order. ``x`` is
        not modified.
    """
    n_items, n_nodes = x.shape
    # Explicit dims (not -1) so that an empty `x` still reshapes cleanly.
    x_pos = pos[x.reshape(-1)].reshape(n_items, n_nodes, 3)
    x_mean = x_pos.mean(axis=1, keepdims=True)
    # Distances must live in a float array: storing them in a copy of the
    # integer index array truncates them (sub-unit distances all become 0),
    # which silently turns the sort into a no-op.
    dist = np.sqrt(np.sum((x_pos - x_mean) ** 2, axis=2))
    order = np.argsort(dist, axis=1, kind="stable")
    return np.take_along_axis(x, order, axis=1)

def face_sorted(cells_faces, faces, pos):
    """Reorder each cell's face list by distance of face centroids to the cell centre.

    Args:
        cells_faces: (C, K) integer array of face indices per cell.
        faces: (F, V) integer array of node indices per face.
        pos: (N, 3) node coordinates indexed by ``faces``.

    Returns:
        A new (C, K) array with ``cells_faces``'s dtype. Each row is ordered
        by ascending Euclidean distance between a face's centroid and the
        mean of that cell's face centroids. Ties keep their original order.
    """
    n_cells, n_cell_faces = cells_faces.shape
    face_pos = pos[faces.reshape(-1)].reshape(faces.shape[0], faces.shape[1], 3)
    face_centres = face_pos.mean(axis=1)

    cell_pos = face_centres[cells_faces.reshape(-1)].reshape(n_cells, n_cell_faces, 3)
    cell_mean = cell_pos.mean(axis=1, keepdims=True)
    # Keep distances in a float array; writing them into a copy of the
    # integer cells_faces array would truncate them and break the sort.
    dist = np.sqrt(np.sum((cell_pos - cell_mean) ** 2, axis=2))
    order = np.argsort(dist, axis=1, kind="stable")
    return np.take_along_axis(cells_faces, order, axis=1)

def _surface_face_masks(faces, faces_edges, node_type0):
    """Return (plate_mask, obstacle_mask): 0/1 int arrays with one entry per face.

    A face is a surface face when the last entry of its ``faces_edges`` row is
    negative (cells_to_faces pads single-owner faces with -1e9). Its type is
    read from its FIRST node in ``node_type0``. NORMAL/HANDLE surface faces go
    to the first mask, and OBSTACLE surface faces to the second.
    """
    is_surface = np.asarray(faces_edges)[:, -1] < 0
    first_node_type = np.asarray(node_type0).reshape(-1)[faces[:, 0]]
    is_plate = (first_node_type == NodeType.NORMAL) | (first_node_type == NodeType.HANDLE)
    is_obstacle = first_node_type == NodeType.OBSTACLE
    return (is_plate & is_surface).astype(int), (is_obstacle & is_surface).astype(int)


def GetDataSetDP(dataset_dir, split, batch_size = 1, output_file_path = None):
    """Preprocess one split of the deforming_plate dataset into a new HDF5 file.

    For every trace in ``<dataset_dir>/<split>.h5``, the mesh topology (edges,
    faces, face ownership, surface masks) is built once from frame 0's cells.
    Every ``step_slice``-th frame is then written as group
    ``<trace_id>/<frame_id>`` containing the arrays in ``need_to_store``.

    Args:
        dataset_dir: directory containing ``<split>.h5``.
        split: split name used for the input file and the default output name.
        batch_size: batch size of the returned DataLoader.
        output_file_path: destination HDF5 file. Defaults to the previously
            hard-coded location, so existing callers are unaffected.

    Returns:
        A DataLoader (collate_fn=custom_collate) over ``frame_data_collected``.
        NOTE: nothing is appended to that list, so the loader is empty. The
        useful product of this function is the HDF5 file it writes.
    """
    file_path = os.path.join(dataset_dir, split + ".h5")
    if output_file_path is None:
        output_file_path = f"/meshgraphnet_data/deepmind_h5/deforming_plate/data_{split}_trace99_trial.h5"
    frame_data_collected = []
    step_slice = 4           # frames between an input frame and its target
    real_slice = step_slice  # stride between stored input frames (same for every split)

    with h5py.File(output_file_path, 'w') as f:
        with h5py.File(file_path, 'r') as file:
            for trace_id in file.keys():
                print(trace_id)
                group_1 = f.create_group(str(trace_id))
                trace_data = file[trace_id]
                trace_length = trace_data["stress"].shape[0]

                # Topology comes from frame 0 and is reused for every frame.
                cells0 = trace_data["cells"][0]
                cells_edges = cells_to_edges(cells0)
                faces, faces_edges, cells_faces = cells_to_faces(cells0)
                faces = faces.astype(int)
                # Vectorized: the per-face loops re-read the h5py dataset once per face.
                mask, mask1 = _surface_face_masks(faces, faces_edges, trace_data["node_type"][0])

                for frame_id in range(0, trace_length - step_slice, real_slice):
                    group_2 = group_1.create_group(str(frame_id))
                    # Each h5py index is a disk read, so read every slice once.
                    pos = trace_data["pos"][frame_id]
                    stress = trace_data["stress"][frame_id]
                    node_type = trace_data["node_type"][frame_id]
                    target_pos = trace_data["pos"][frame_id + step_slice]
                    target_stress = trace_data["stress"][frame_id + step_slice]

                    # Rows of NORMAL nodes are zeroed; all other rows keep the target positions.
                    next_pos_need = target_pos.copy()
                    next_pos_need[(node_type == NodeType.NORMAL).reshape(-1), :] = 0

                    if frame_id > 0:
                        last_pos = trace_data["pos"][frame_id - step_slice]
                        last_stress = trace_data["stress"][frame_id - step_slice]
                    else:
                        # No earlier frame: the current one stands in for it.
                        last_pos, last_stress = pos, stress

                    new_faces = node_sorted(faces, pos)
                    need_to_store = {
                        "final_flag": np.array([((frame_id + real_slice) >= (trace_length - step_slice))]),
                        "mesh_pos": trace_data["mesh_pos"][frame_id],
                        "mesh_edge": cells_edges,
                        "cells_faces": face_sorted(cells_faces, new_faces, pos),
                        "last_stress": last_stress,
                        "last_pos": last_pos,
                        "world_pos": pos,
                        "stress": stress,
                        "node_type": node_type,
                        "elements": node_sorted(cells0, pos),
                        "faces": new_faces,
                        "world_edge": nodes_to_edges(pos, node_type),
                        "faces_edges": faces_edges,
                        "next_pos_need": next_pos_need,
                        "faces_to_faces": faces_to_faces(faces, pos, node_type, mask, mask1),
                        "target": np.concatenate([target_pos, target_stress - stress], axis=1),
                    }
                    # create_dataset copies the data, so sharing arrays between keys is safe.
                    for name, value in need_to_store.items():
                        group_2.create_dataset(name, data=value)

    dataset = DataLoader(dataset = frame_data_collected, batch_size = batch_size, shuffle = False, collate_fn=custom_collate)
    return dataset
if __name__ == '__main__':
    # Script entry point: preprocess the "test" split of deforming_plate.
    data_root = "/meshgraphnet_data/deepmind_h5/deforming_plate/"
    dataset = GetDataSetDP(data_root, "test")
    
    
